{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "\n",
    "from argparse import Namespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Namespace(\n",
    "    raw_dataset_csv=\"data/surnames/surnames.csv\",\n",
    "    train_proportion=0.7,\n",
    "    val_proportion=0.15,\n",
    "    test_proportion=0.15,\n",
    "    output_munged_csv=\"data/surnames/surnames_with_splits.csv\",\n",
    "    seed=1337\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read raw data\n",
    "surnames = pd.read_csv(args.raw_dataset_csv, header=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>surname</th>\n",
       "      <th>nationality</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Woodford</td>\n",
       "      <td>English</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Coté</td>\n",
       "      <td>French</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Kore</td>\n",
       "      <td>English</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Koury</td>\n",
       "      <td>Arabic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Lebzak</td>\n",
       "      <td>Russian</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    surname nationality\n",
       "0  Woodford     English\n",
       "1      Coté      French\n",
       "2      Kore     English\n",
       "3     Koury      Arabic\n",
       "4    Lebzak     Russian"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "surnames.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Arabic',\n",
       " 'Chinese',\n",
       " 'Czech',\n",
       " 'Dutch',\n",
       " 'English',\n",
       " 'French',\n",
       " 'German',\n",
       " 'Greek',\n",
       " 'Irish',\n",
       " 'Italian',\n",
       " 'Japanese',\n",
       " 'Korean',\n",
       " 'Polish',\n",
       " 'Portuguese',\n",
       " 'Russian',\n",
       " 'Scottish',\n",
       " 'Spanish',\n",
       " 'Vietnamese'}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unique classes\n",
    "set(surnames.nationality)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Splitting train by nationality\n",
    "# Create dict\n",
    "by_nationality = collections.defaultdict(list)\n",
    "for _, row in surnames.iterrows():\n",
    "    by_nationality[row.nationality].append(row.to_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create split data\n",
    "final_list = []\n",
    "np.random.seed(args.seed)\n",
    "for _, item_list in sorted(by_nationality.items()):\n",
    "    np.random.shuffle(item_list)\n",
    "    n = len(item_list)\n",
    "    n_train = int(args.train_proportion*n)\n",
    "    n_val = int(args.val_proportion*n)\n",
    "    n_test = int(args.test_proportion*n)\n",
    "    \n",
    "    # Give data point a split attribute\n",
    "    for item in item_list[:n_train]:\n",
    "        item['split'] = 'train'\n",
    "    for item in item_list[n_train:n_train+n_val]:\n",
    "        item['split'] = 'val'\n",
    "    for item in item_list[n_train+n_val:]:\n",
    "        item['split'] = 'test'  \n",
    "    \n",
    "    # Add to final list\n",
    "    final_list.extend(item_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write split data to file\n",
    "final_surnames = pd.DataFrame(final_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "train    7680\n",
       "test     1660\n",
       "val      1640\n",
       "Name: split, dtype: int64"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_surnames.split.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>nationality</th>\n",
       "      <th>split</th>\n",
       "      <th>surname</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Arabic</td>\n",
       "      <td>train</td>\n",
       "      <td>Totah</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Arabic</td>\n",
       "      <td>train</td>\n",
       "      <td>Abboud</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Arabic</td>\n",
       "      <td>train</td>\n",
       "      <td>Fakhoury</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Arabic</td>\n",
       "      <td>train</td>\n",
       "      <td>Srour</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Arabic</td>\n",
       "      <td>train</td>\n",
       "      <td>Sayegh</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  nationality  split   surname\n",
       "0      Arabic  train     Totah\n",
       "1      Arabic  train    Abboud\n",
       "2      Arabic  train  Fakhoury\n",
       "3      Arabic  train     Srour\n",
       "4      Arabic  train    Sayegh"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_surnames.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write munged data to CSV\n",
    "final_surnames.to_csv(args.output_munged_csv, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlpbook",
   "language": "python",
   "name": "nlpbook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "12px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": "5",
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
